{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/zheyuanzhang/anaconda3/envs/FRS/lib/python3.8/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "from utils import *\n",
    "import pandas as pd\n",
    "from transformers import BertTokenizer, BertModel\n",
    "from sklearn.decomposition import PCA\n",
    "import torch\n",
    "from torch_geometric.data import HeteroData\n",
    "from sklearn.preprocessing import StandardScaler, MinMaxScaler\n",
    "\n",
    "import warnings\n",
    "warnings.filterwarnings('ignore')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Benchmark Construction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load in necessary data.\n",
    "df_fndds = pd.read_csv('../processed_data/fndds.csv')\n",
    "df_nutrition = pd.read_csv('../processed_data/food_tagging.csv')\n",
    "df_user = pd.read_csv('../processed_data/user_tagging.csv')\n",
    "df_user_food = pd.read_csv('../processed_data/food_user.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Step 1: Extract relevant columns\n",
    "columns_to_extract = ['SEQN', 'gender', 'race', 'household_income', 'education', 'age_group']\n",
    "df_user_extracted = df_user[columns_to_extract]\n",
    "\n",
    "# Step 2: One-Hot Encode the categorical columns\n",
    "df_user_vector = pd.get_dummies(df_user_extracted, columns=['gender', 'race', 'household_income', 'education', 'age_group'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Step 1: Concatenate Ingredient Descriptions by Food\n",
    "ingredient_concat = df_fndds.groupby('food_id')['ingredient_desc'].apply(lambda x: ' '.join(x)).reset_index()\n",
    "df_fndds = df_fndds.drop(columns=['ingredient_desc']).drop_duplicates(subset='food_id')\n",
    "food_fndds = pd.merge(df_fndds, ingredient_concat, on='food_id')\n",
    "\n",
    "# Step 2: Concatenate Food Description, WWEIA Description, and Ingredient Description\n",
    "food_fndds['combined_desc'] = food_fndds[['food_desc', 'WWEIA_desc', 'ingredient_desc']].agg(' '.join, axis=1)\n",
    "\n",
    "# Step 3: Convert Combined Descriptions into BERT Embeddings\n",
    "tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')\n",
    "model = BertModel.from_pretrained('bert-base-uncased')\n",
    "\n",
    "def get_bert_embedding(text):\n",
    "    inputs = tokenizer(text, return_tensors='pt', truncation=True, padding=True, max_length=512)\n",
    "    outputs = model(**inputs)\n",
    "    return outputs.last_hidden_state.mean(dim=1).squeeze().detach().numpy()\n",
    "\n",
    "food_fndds['bert_embedding'] = food_fndds['combined_desc'].apply(get_bert_embedding)\n",
    "\n",
    "# Convert the embeddings to a DataFrame\n",
    "embeddings = np.vstack(food_fndds['bert_embedding'].values)\n",
    "embeddings_df = pd.DataFrame(embeddings, index=food_fndds['food_id'])\n",
    "\n",
    "# Step 4: Apply PCA to Reduce Dimensionality to 50\n",
    "pca = PCA(n_components=50)\n",
    "reduced_embeddings = pca.fit_transform(embeddings)\n",
    "reduced_embeddings_df = pd.DataFrame(reduced_embeddings, index=food_fndds['food_id'])\n",
    "\n",
    "# Result: Table with food_id as index and PCA-reduced embeddings as columns\n",
    "reduced_embeddings_df.columns = [f'PC{i+1}' for i in range(50)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Step 1: Standardize and Normalize Nutrition Vectors\n",
    "nutrition_vectors = df_nutrition.iloc[:, 1:17].values  # Assuming the first column is food_id\n",
    "\n",
    "# Standardize (z-score normalization)\n",
    "scaler = StandardScaler()\n",
    "standardized_nutrition = scaler.fit_transform(nutrition_vectors)\n",
    "\n",
    "# Normalize (min-max scaling)\n",
    "min_max_scaler = MinMaxScaler()\n",
    "normalized_nutrition = min_max_scaler.fit_transform(standardized_nutrition)\n",
    "\n",
    "# Convert to DataFrame\n",
    "nutrition_df = pd.DataFrame(normalized_nutrition, columns=df_nutrition.columns[1:17])\n",
    "nutrition_df.insert(0, 'food_id', df_nutrition['food_id'])\n",
    "\n",
    "# Step 2: Merge with Reduced Embeddings\n",
    "df_nutrition_vector = pd.merge(nutrition_df, reduced_embeddings_df, on='food_id', how='left')\n",
    "df_nutrition_vector.fillna(0, inplace=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Here we created two tables with only health tags. This will be used later for node features for health loss downstream. \n",
    "df_user_copy = df_user.copy()\n",
    "nutrition_columns = df_nutrition.columns[17:].tolist()\n",
    "# Create new columns in df_user_copy if they are missing\n",
    "for nutrition in nutrition_columns:\n",
    "    user_col = 'user_' + nutrition\n",
    "    if user_col not in df_user_copy.columns:\n",
    "        df_user_copy[user_col] = 0\n",
    "\n",
    "# Reorder columns\n",
    "ordered_columns = ['SEQN'] + ['user_' + col for col in nutrition_columns]\n",
    "df_user_tags = df_user_copy[ordered_columns]\n",
    "df_nutrition_tags = df_nutrition[['food_id'] + nutrition_columns]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "df_main = pd.merge(df_user_food, df_user, on='SEQN', how='left')\n",
    "df_main = pd.merge(df_main, df_nutrition, on='food_id', how='left')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "Filter 1: We only leave the adults here. This becuase nutrition suggestions differ for children and adults.\n",
    "\"\"\"\n",
    "df_main = df_main[df_main['age'] >= 18]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "nutrition_list = ['calorie', 'carb', 'protein', 'sugar', 'fiber', 'saturated_fat', 'cholesterol',\n",
    "                  'sodium', 'folic_acid', 'calcium', 'iron', 'potassium', 'vitamin_b12', 'vitamin_c', \n",
    "                  'phosphorus', 'vitamin_d']\n",
    "macro_nutrition_list = ['calorie', 'carb', 'protein', 'fiber', 'saturated_fat', 'cholesterol', 'sugar']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "def benchmark_construction(df_main, nutrition_list):\n",
    "    # Here we match the nutrition tags with the user tags for statistical analysis.\n",
    "    df_anal = pd.DataFrame(df_main[['SEQN', 'food_id']])\n",
    "    df_anal['total_match'] = 0\n",
    "    df_anal['total_opposite'] = 0\n",
    "    for nutrition in nutrition_list:\n",
    "        user_high = 'user_high_' + nutrition\n",
    "        user_low = 'user_low_' + nutrition\n",
    "        food_high = 'high_' + nutrition\n",
    "        food_low = 'low_' + nutrition\n",
    "\n",
    "        if user_high in df_main.columns:\n",
    "            df_anal[nutrition+'_high_matching'] = (df_main[user_high] & df_main[food_high]).astype(int)\n",
    "            df_anal[nutrition+'_high_opposite'] = (df_main[user_high] & df_main[food_low]).astype(int)\n",
    "            df_anal['total_match'] += df_anal[nutrition+'_high_matching']\n",
    "            df_anal['total_opposite'] += df_anal[nutrition+'_high_opposite']\n",
    "        if user_low in df_main.columns:\n",
    "            df_anal[nutrition+'_low_matching'] = (df_main[user_low] & df_main[food_low]).astype(int)\n",
    "            df_anal[nutrition+'_low_opposite'] = (df_main[user_low] & df_main[food_high]).astype(int)\n",
    "            df_anal['total_match'] += df_anal[nutrition+'_low_opposite']\n",
    "            df_anal['total_opposite'] += df_anal[nutrition+'_low_matching']\n",
    "\n",
    "    df_anal['clean_score'] = df_anal['total_match'] - df_anal['total_opposite']\n",
    "\n",
    "    # we only keep users who have consumed more than 10 healthy dishes.\n",
    "    df_summary = df_anal[df_anal['clean_score'] > 0].groupby('SEQN')[['total_match', 'total_opposite', 'clean_score']].sum()\n",
    "    df_summary['dish_count'] = df_anal[df_anal['clean_score'] > 0].groupby('SEQN')['food_id'].count()\n",
    "    df_summary = df_summary.loc[(df_summary['dish_count'] > 10)].sort_values('clean_score', ascending=False)\n",
    "\n",
    "    # Retrieve the valid user list. \n",
    "    valid_SEQN_set = df_summary.index\n",
    "    df_anal = df_anal[df_anal['SEQN'].isin(valid_SEQN_set)]\n",
    "\n",
    "    # Remap the rest features back.\n",
    "    df_main = pd.merge(df_main[['SEQN', 'food_id']].drop_duplicates(), df_anal, on=['SEQN', 'food_id'], how='right')\n",
    "\n",
    "    return df_main, df_summary"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Benchmark "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "df_macro, df_macro_summary = benchmark_construction(df_main, macro_nutrition_list)\n",
    "df_all, df_all_summary = benchmark_construction(df_main, nutrition_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "122009\n",
      "6769\n",
      "314224\n",
      "8170\n"
     ]
    }
   ],
   "source": [
    "# The number of healthy user-food interactions\n",
    "print(df_macro_summary['dish_count'].sum())\n",
    "# The unique food items\n",
    "print(df_macro['food_id'].nunique())\n",
    "# The total number of user-food interactions\n",
    "print(len(df_macro))\n",
    "# The number of unique users\n",
    "print(df_macro['SEQN'].nunique())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "207949\n",
      "7516\n",
      "488223\n",
      "13282\n"
     ]
    }
   ],
   "source": [
    "# The number of healthy user-food interactions\n",
    "print(df_all_summary['dish_count'].sum())\n",
    "# The unique food items\n",
    "print(df_all['food_id'].nunique())\n",
    "# The total number of user-food interactions\n",
    "print(len(df_all))\n",
    "# The number of unique users\n",
    "print(df_all['SEQN'].nunique())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "df_all.to_csv('../processed_data/raw_benchmark.csv', index=False)\n",
    "df_macro.to_csv('../processed_data/raw_benchmark_macro.csv', index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "def graph_construction(df_main, df_user_vector, df_nutrition_vector, df_user_tags, df_nutrition_tags, output_path, macro_only=False):\n",
    "\n",
    "    # Get unique user and food IDs\n",
    "    user_ids = df_main['SEQN'].unique()\n",
    "    food_ids = df_main['food_id'].unique()\n",
    "\n",
    "    # Create node features\n",
    "    user_features = torch.tensor(user_ids, dtype=torch.long)\n",
    "    food_features = torch.tensor(food_ids, dtype=torch.long)\n",
    "\n",
    "    # Create edge indices\n",
    "    user_indices = torch.tensor(df_main['SEQN'].map(lambda x: np.where(user_ids == x)[0][0]), dtype=torch.long)\n",
    "    food_indices = torch.tensor(df_main['food_id'].map(lambda x: np.where(food_ids == x)[0][0]), dtype=torch.long)\n",
    "    # Create edge indices\n",
    "    edge_index = torch.stack([user_indices, food_indices], dim=0)\n",
    "\n",
    "    # Create edge labels\n",
    "    clean_scores = df_main['clean_score']\n",
    "    clean_scores = df_main[df_main['clean_score'] > 0]['clean_score']\n",
    "    \n",
    "    # Create edge labels\n",
    "    edge_labels = torch.stack([user_indices[clean_scores.index], food_indices[clean_scores.index]], dim=0)\n",
    "    \n",
    "    graph = HeteroData()\n",
    "    user_feature_vectors = df_user_vector.drop_duplicates().set_index('SEQN').loc[user_ids].values\n",
    "    food_feature_vectors = df_nutrition_vector.drop_duplicates().set_index('food_id').loc[food_ids].values\n",
    "    print(user_feature_vectors.shape)\n",
    "    graph['user'].x = torch.tensor(user_feature_vectors, dtype=torch.float)\n",
    "    graph['food'].x = torch.tensor(food_feature_vectors, dtype=torch.float)\n",
    "\n",
    "    graph['user'].node_id = user_features\n",
    "    graph['food'].node_id = food_features\n",
    "    # Add num_nodes attribute to every node type\n",
    "    graph['user'].num_nodes = len(user_ids)\n",
    "    graph['food'].num_nodes = len(food_ids)\n",
    "    graph[('user', 'eats', 'food')].edge_index = edge_index\n",
    "    graph[('user', 'eats', 'food')].edge_label_index = edge_labels\n",
    "\n",
    "    macro_nutrition_length = 14  # 7 macro nutrition types, each with 'low' and 'high'\n",
    "    user_tag_features = []\n",
    "    for user_id in user_ids:\n",
    "        if user_id in df_user_tags['SEQN'].values:\n",
    "            user_tag_vector = torch.tensor(df_user_tags[df_user_tags['SEQN'] == user_id].iloc[0, 1:].values, dtype=torch.float)\n",
    "            if macro_only:\n",
    "                user_tag_vector = user_tag_vector[:macro_nutrition_length]\n",
    "            user_tag_features.append(user_tag_vector)\n",
    "        else:\n",
    "            raise ValueError(f'User {user_id} does not have any health tags.')\n",
    "    graph['user'].tags = torch.stack(user_tag_features)\n",
    "\n",
    "    food_tag_features = []\n",
    "    for food_id in food_ids:\n",
    "        if food_id in df_nutrition_tags['food_id'].values:\n",
    "            food_tag_vector = torch.tensor(df_nutrition_tags[df_nutrition_tags['food_id'] == food_id].iloc[0, 1:].values, dtype=torch.float)\n",
    "            if macro_only:\n",
    "                food_tag_vector = food_tag_vector[:macro_nutrition_length]\n",
    "            food_tag_features.append(food_tag_vector)\n",
    "        else:\n",
    "            raise ValueError(f'Food {food_id} does not have any nutrition tags.')\n",
    "    graph['food'].tags = torch.stack(food_tag_features)\n",
    "\n",
    "    torch.save(graph, output_path)\n",
    "    return graph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(13282, 38)\n",
      "(8170, 38)\n"
     ]
    }
   ],
   "source": [
    "graph_all = graph_construction(df_all, df_user_vector, df_nutrition_vector, df_user_tags, df_nutrition_tags, '../processed_data/benchmark_all.pt')\n",
    "graph_macro = graph_construction(df_macro, df_user_vector, df_nutrition_vector, df_user_tags, df_nutrition_tags, '../processed_data/benchmark_macro.pt', macro_only=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "FRS",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
